package cn.dfun.sample.flink.udftest

import cn.dfun.sample.flink.apitest.SensorReading
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor
import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.table.api.EnvironmentSettings
import org.apache.flink.table.functions.{AggregateFunction, ScalarFunction, TableFunction}
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object AggregateFunctionTest {

  /** Input file used when no path is given as the first command-line argument. */
  private val DefaultInputPath = "C:\\wor\\flink-sample\\src\\main\\resources\\sensor"

  /**
   * Reads sensor readings from a text file and computes the average temperature per sensor id
   * with the [[AvgTemp]] aggregate function, once through the Table API and once through SQL.
   * Both results are printed as retract streams.
   *
   * @param args optional; `args(0)` overrides the input file path (defaults to `DefaultInputPath`)
   */
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    env.setParallelism(1)
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
    val settings = EnvironmentSettings.newInstance()
      .useBlinkPlanner()
      .inStreamingMode()
      .build()
    val tableEnv = StreamTableEnvironment.create(env, settings)

    val inputPath = args.headOption.getOrElse(DefaultInputPath)
    val inputStream = env.readTextFile(inputPath)

    val dataStream = inputStream
      .map { line =>
        // Each line is "id,timestamp,temperature". Fields are trimmed so that ", "-separated
        // input does not make toLong / toDouble throw.
        val fields = line.split(",").map(_.trim)
        SensorReading(fields(0), fields(1).toLong, fields(2).toDouble)
      }
      .assignTimestampsAndWatermarks(new BoundedOutOfOrdernessTimestampExtractor[SensorReading](Time.seconds(1)) {
        // The timestamp field is multiplied by 1000 to obtain the millisecond event time.
        override def extractTimestamp(element: SensorReading): Long = element.timestamp * 1000L
      })

    // The 'timestamp field becomes the event-time attribute, exposed as 'ts.
    val sensorTable = tableEnv.fromDataStream(dataStream, 'id, 'temperature, 'timestamp.rowtime as 'ts)
    val avgTemp = new AvgTemp()

    // Table API
    val resultTable = sensorTable
      .groupBy('id)
      .aggregate(avgTemp('temperature) as 'avgTemp)
      .select('id, 'avgTemp)

    // SQL: the same aggregation through a registered view and function.
    tableEnv.createTemporaryView("sensor", sensorTable)
    tableEnv.registerFunction("avgTemp", avgTemp)
    val resultSqlTable = tableEnv.sqlQuery(
      """
        |select id, avgTemp(temperature)
        |from sensor
        |group by id
      """.stripMargin)

    // A grouped aggregate updates earlier results, so it must be emitted as a retract stream.
    resultTable.toRetractStream[Row].print("result")
    resultSqlTable.toRetractStream[Row].print("sql")
    env.execute("aggregate function test")
  }
}

// Accumulator (aggregation state) for AvgTemp: running sum and count of temperatures.
class AvgTempAcc {
  // Total of every temperature accumulated so far.
  var sum = 0.0
  // Number of temperatures that have been added into `sum`.
  var count = 0
}

/**
 * User-defined aggregate function that computes the average temperature of a group
 * (per sensor id in this file's queries).
 *
 * The state is an [[AvgTempAcc]] holding the running sum and count of temperatures.
 * It also implements the optional `retract`, `merge` and `resetAccumulator` methods of
 * Flink's aggregate-function contract. These let it run on retracting inputs, with
 * merging (session) windows and with two-phase aggregation.
 */
class AvgTemp extends AggregateFunction[Double, AvgTempAcc] {

  /**
   * Returns the current average.
   *
   * Returns `Double.NaN` when nothing is accumulated (for example after every value was
   * retracted). This avoids dividing a possibly non-zero residual `sum` by zero, which
   * would produce ±Infinity.
   */
  override def getValue(acc: AvgTempAcc): Double =
    if (acc.count == 0) Double.NaN else acc.sum / acc.count

  /** Creates a fresh, empty accumulator. */
  override def createAccumulator(): AvgTempAcc = new AvgTempAcc()

  /** Adds one temperature reading to the accumulator. */
  def accumulate(acc: AvgTempAcc, temp: Double): Unit = {
    acc.sum += temp
    acc.count += 1
  }

  /** Removes a previously accumulated temperature (used when the input contains retractions). */
  def retract(acc: AvgTempAcc, temp: Double): Unit = {
    acc.sum -= temp
    acc.count -= 1
  }

  /** Folds the partial accumulators in `others` into `acc`. */
  def merge(acc: AvgTempAcc, others: java.lang.Iterable[AvgTempAcc]): Unit = {
    val it = others.iterator()
    while (it.hasNext) {
      val other = it.next()
      acc.sum += other.sum
      acc.count += other.count
    }
  }

  /** Resets `acc` to the empty state so it can be reused. */
  def resetAccumulator(acc: AvgTempAcc): Unit = {
    acc.sum = 0.0
    acc.count = 0
  }
}
